{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/anaconda3/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    }
   ],
   "source": [
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "import collections\n",
    "import csv\n",
    "import os\n",
    "import modeling\n",
    "import optimization\n",
    "import tokenization\n",
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class InputExample(object):\n",
    "    \"\"\"One raw example for a simple sequence classification task.\"\"\"\n",
    "\n",
    "    def __init__(self, guid, text_a, text_b=None, label=None):\n",
    "        \"\"\"Stores the given fields on the instance without modifying them.\n",
    "\n",
    "        Args:\n",
    "          guid: Unique id for the example.\n",
    "          text_a: string. Untokenized text of the first sequence. This is the\n",
    "            only sequence needed for single-sequence tasks.\n",
    "          text_b: (Optional) string. Untokenized text of the second sequence,\n",
    "            needed only for sequence-pair tasks.\n",
    "          label: (Optional) string. Label of the example. Expected for train\n",
    "            and dev examples, not for test examples.\n",
    "        \"\"\"\n",
    "        self.guid, self.label = guid, label\n",
    "        self.text_a, self.text_b = text_a, text_b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class InputFeatures(object):\n",
    "    \"\"\"Plain container for the feature values of one example.\"\"\"\n",
    "\n",
    "    def __init__(self, input_ids, input_mask, segment_ids, label_id,\n",
    "                 is_real_example=True):\n",
    "        \"\"\"Stores every argument on the instance unchanged.\n",
    "\n",
    "        Args:\n",
    "          input_ids: Stored as `self.input_ids`.\n",
    "          input_mask: Stored as `self.input_mask`.\n",
    "          segment_ids: Stored as `self.segment_ids`.\n",
    "          label_id: Stored as `self.label_id`.\n",
    "          is_real_example: Stored as `self.is_real_example`; defaults to True.\n",
    "        \"\"\"\n",
    "        self.input_ids, self.input_mask = input_ids, input_mask\n",
    "        self.segment_ids, self.label_id = segment_ids, label_id\n",
    "        self.is_real_example = is_real_example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DataProcessor(object):\n",
    "    \"\"\"Abstract interface for loaders of sequence classification data sets.\n",
    "\n",
    "    The `get_*` methods raise `NotImplementedError` here and are meant to be\n",
    "    overridden; `_read_tsv` is a concrete helper for reading TSV files.\n",
    "    \"\"\"\n",
    "\n",
    "    def get_train_examples(self, data_dir):\n",
    "        \"\"\"Returns the `InputExample`s of the train set.\"\"\"\n",
    "        raise NotImplementedError()\n",
    "\n",
    "    def get_dev_examples(self, data_dir):\n",
    "        \"\"\"Returns the `InputExample`s of the dev set.\"\"\"\n",
    "        raise NotImplementedError()\n",
    "\n",
    "    def get_test_examples(self, data_dir):\n",
    "        \"\"\"Returns the `InputExample`s used for prediction.\"\"\"\n",
    "        raise NotImplementedError()\n",
    "\n",
    "    def get_labels(self):\n",
    "        \"\"\"Returns the list of labels of this data set.\"\"\"\n",
    "        raise NotImplementedError()\n",
    "\n",
    "    @classmethod\n",
    "    def _read_tsv(cls, input_file, quotechar=None):\n",
    "        \"\"\"Parses a tab-separated file into a list of rows.\n",
    "\n",
    "        Args:\n",
    "          input_file: Path passed to `tf.gfile.Open` in text read mode.\n",
    "          quotechar: Forwarded to `csv.reader`; defaults to None.\n",
    "\n",
    "        Returns:\n",
    "          A list holding every row produced by the CSV reader, in file order.\n",
    "        \"\"\"\n",
    "        with tf.gfile.Open(input_file, \"r\") as tsv_file:\n",
    "            rows = csv.reader(tsv_file, delimiter=\"\\t\", quotechar=quotechar)\n",
    "            # The reader is lazy, so drain it while the file is still open.\n",
    "            return list(rows)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ColaProcessor(DataProcessor):\n",
    "    \"\"\"Processor for the CoLA data set (GLUE version).\"\"\"\n",
    "\n",
    "    def get_train_examples(self, data_dir):\n",
    "        \"\"\"See base class.\"\"\"\n",
    "        rows = self._read_tsv(os.path.join(data_dir, \"train.tsv\"))\n",
    "        return self._create_examples(rows, \"train\")\n",
    "\n",
    "    def get_dev_examples(self, data_dir):\n",
    "        \"\"\"See base class.\"\"\"\n",
    "        rows = self._read_tsv(os.path.join(data_dir, \"dev.tsv\"))\n",
    "        return self._create_examples(rows, \"dev\")\n",
    "\n",
    "    def get_test_examples(self, data_dir):\n",
    "        \"\"\"See base class.\"\"\"\n",
    "        rows = self._read_tsv(os.path.join(data_dir, \"test.tsv\"))\n",
    "        return self._create_examples(rows, \"test\")\n",
    "\n",
    "    def get_labels(self):\n",
    "        \"\"\"See base class.\"\"\"\n",
    "        return [\"0\", \"1\"]\n",
    "\n",
    "    def _create_examples(self, lines, set_type):\n",
    "        \"\"\"Builds one `InputExample` per data row of the given split.\n",
    "\n",
    "        For the \"test\" split the first row is skipped as a header, the text is\n",
    "        read from column 1 and every example gets the placeholder label \"0\".\n",
    "        For any other split the text is read from column 3 and the label from\n",
    "        column 1.\n",
    "\n",
    "        Args:\n",
    "          lines: Sequence of rows, each an indexable sequence of fields.\n",
    "          set_type: Split name; used in the guid and to select the column layout.\n",
    "\n",
    "        Returns:\n",
    "          A list of `InputExample`s with guid \"<set_type>-<row index>\".\n",
    "        \"\"\"\n",
    "        is_test = set_type == \"test\"\n",
    "        examples = []\n",
    "        for (row_index, row) in enumerate(lines):\n",
    "            # Only the test set has a header\n",
    "            if is_test and row_index == 0:\n",
    "                continue\n",
    "            if is_test:\n",
    "                text_a = tokenization.convert_to_unicode(row[1])\n",
    "                label = \"0\"\n",
    "            else:\n",
    "                text_a = tokenization.convert_to_unicode(row[3])\n",
    "                label = tokenization.convert_to_unicode(row[1])\n",
    "            examples.append(\n",
    "                InputExample(guid=\"%s-%s\" % (set_type, row_index),\n",
    "                             text_a=text_a, text_b=None, label=label))\n",
    "        return examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PaddingInputExample(object):\n",
    "    \"\"\"Marker type for placeholder examples.\n",
    "\n",
    "    `convert_single_example` maps any instance of this class to all-zero\n",
    "    features flagged with `is_real_example=False`.\n",
    "    \"\"\"\n",
    "\n",
    "\n",
    "def _truncate_seq_pair(tokens_a, tokens_b, max_length):\n",
    "    \"\"\"Truncates a pair of token lists in place to a maximum total length.\n",
    "\n",
    "    One token at a time is removed from the end of the currently longer list\n",
    "    (from `tokens_b` on ties) until the combined length fits.\n",
    "\n",
    "    Args:\n",
    "      tokens_a: List of tokens of the first sequence; modified in place.\n",
    "      tokens_b: List of tokens of the second sequence; modified in place.\n",
    "      max_length: Maximum allowed value of len(tokens_a) + len(tokens_b).\n",
    "    \"\"\"\n",
    "    while len(tokens_a) + len(tokens_b) > max_length:\n",
    "        # Guard against popping from an empty list when max_length < 0.\n",
    "        if not tokens_a and not tokens_b:\n",
    "            break\n",
    "        if len(tokens_a) > len(tokens_b):\n",
    "            tokens_a.pop()\n",
    "        else:\n",
    "            tokens_b.pop()\n",
    "\n",
    "\n",
    "def convert_single_example(ex_index, example, label_list, max_seq_length,\n",
    "                           tokenizer):\n",
    "    \"\"\"Converts a single `InputExample` into a single `InputFeatures`.\n",
    "\n",
    "    Args:\n",
    "      ex_index: Index of the example; examples with index < 5 are logged.\n",
    "      example: An `InputExample`, or a `PaddingInputExample` placeholder.\n",
    "      label_list: Sequence of labels; a label's id is its position in it.\n",
    "      max_seq_length: Exact length of every feature list in the result.\n",
    "      tokenizer: Object providing `tokenize(text)` and\n",
    "        `convert_tokens_to_ids(tokens)`.\n",
    "\n",
    "    Returns:\n",
    "      An `InputFeatures` whose `input_ids`, `input_mask` and `segment_ids`\n",
    "      all have length `max_seq_length`. For a `PaddingInputExample` all of\n",
    "      them are zeros, `label_id` is 0 and `is_real_example` is False.\n",
    "\n",
    "    Raises:\n",
    "      KeyError: If `example.label` is not contained in `label_list`.\n",
    "    \"\"\"\n",
    "    if isinstance(example, PaddingInputExample):\n",
    "        return InputFeatures(\n",
    "            input_ids=[0] * max_seq_length,\n",
    "            input_mask=[0] * max_seq_length,\n",
    "            segment_ids=[0] * max_seq_length,\n",
    "            label_id=0,\n",
    "            is_real_example=False)\n",
    "\n",
    "    label_map = {label: i for (i, label) in enumerate(label_list)}\n",
    "\n",
    "    tokens_a = tokenizer.tokenize(example.text_a)\n",
    "    tokens_b = None\n",
    "    if example.text_b:\n",
    "        tokens_b = tokenizer.tokenize(example.text_b)\n",
    "\n",
    "    if tokens_b:\n",
    "        # Modifies `tokens_a` and `tokens_b` in place so that the total\n",
    "        # length fits. Account for [CLS], [SEP], [SEP] with \"- 3\".\n",
    "        _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)\n",
    "    elif len(tokens_a) > max_seq_length - 2:\n",
    "        # Account for [CLS] and [SEP] with \"- 2\".\n",
    "        tokens_a = tokens_a[0:(max_seq_length - 2)]\n",
    "\n",
    "    # First segment: [CLS] tokens_a [SEP], all with segment id 0.\n",
    "    tokens = [\"[CLS]\"]\n",
    "    tokens.extend(tokens_a)\n",
    "    tokens.append(\"[SEP]\")\n",
    "    segment_ids = [0] * len(tokens)\n",
    "\n",
    "    # Optional second segment: tokens_b [SEP], all with segment id 1.\n",
    "    if tokens_b:\n",
    "        tokens.extend(tokens_b)\n",
    "        tokens.append(\"[SEP]\")\n",
    "        segment_ids.extend([1] * (len(tokens_b) + 1))\n",
    "\n",
    "    input_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
    "\n",
    "    # The mask is 1 for every position produced so far and 0 for padding.\n",
    "    input_mask = [1] * len(input_ids)\n",
    "\n",
    "    # Zero-pad up to the sequence length.\n",
    "    padding = [0] * (max_seq_length - len(input_ids))\n",
    "    input_ids.extend(padding)\n",
    "    input_mask.extend(padding)\n",
    "    segment_ids.extend(padding)\n",
    "\n",
    "    assert len(input_ids) == max_seq_length\n",
    "    assert len(input_mask) == max_seq_length\n",
    "    assert len(segment_ids) == max_seq_length\n",
    "\n",
    "    label_id = label_map[example.label]\n",
    "    if ex_index < 5:\n",
    "        tf.logging.info(\"*** Example ***\")\n",
    "        tf.logging.info(\"guid: %s\" % (example.guid))\n",
    "        tf.logging.info(\"tokens: %s\" % \" \".join([tokenization.printable_text(x) for x in tokens]))\n",
    "        tf.logging.info(\"input_ids: %s\" % \" \".join([str(x) for x in input_ids]))\n",
    "        tf.logging.info(\"input_mask: %s\" % \" \".join([str(x) for x in input_mask]))\n",
    "        tf.logging.info(\"segment_ids: %s\" % \" \".join([str(x) for x in segment_ids]))\n",
    "        tf.logging.info(\"label: %s (id = %d)\" % (example.label, label_id))\n",
    "\n",
    "    return InputFeatures(\n",
    "        input_ids=input_ids,\n",
    "        input_mask=input_mask,\n",
    "        segment_ids=segment_ids,\n",
    "        label_id=label_id,\n",
    "        is_real_example=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
